"""
    Dataset for distiller
    Author: Heng-Jui Chang (https://github.com/vectominist)
"""

import os
import random
import numpy as np
import pandas as pd
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.dataset import Dataset
import torchaudio

# Buckets whose longest sequence exceeds this length are split in half
# (only when random segment sampling is disabled) to limit memory usage.
HALF_BATCHSIZE_TIME = 99999


class WaveDataset(Dataset):
    """Waveform dataset for Disiller"""

    def __init__(
        self,
        task_config,
        bucket_size,
        file_path,
        sets,
        max_timestep=0,
        libri_root=None,
        **kwargs
    ):
        super().__init__()

        self.task_config = task_config
        self.libri_root = libri_root
        self.sample_length = task_config["sequence_length"]
        if self.sample_length > 0:
            print(
                "[Dataset] - Sampling random segments for training, sample length:",
                self.sample_length,
            )

        # Read file
        self.root = file_path
        tables = [pd.read_csv(os.path.join(file_path, s + ".csv")) for s in sets]
        self.table = pd.concat(tables, ignore_index=True).sort_values(
            by=["length"], ascending=False
        )
        print("[Dataset] - Training data from these sets:", str(sets))

        # Drop sequences that are too long
        if max_timestep > 0:
            self.table = self.table[self.table.length < max_timestep]
        # A negative max_timestep instead drops sequences that are too short
        if max_timestep < 0:
            self.table = self.table[self.table.length > (-1 * max_timestep)]

        X = self.table["file_path"].tolist()
        X_lens = self.table["length"].tolist()
        self.num_samples = len(X)
        print("[Dataset] - Number of individual training instances:", self.num_samples)

        # Use bucketing to allow different batch sizes at run time
        self.X = []
        batch_x, batch_len = [], []

        for x, x_len in zip(X, X_lens):
            batch_x.append(x)
            batch_len.append(x_len)

            # Fill in batch_x until batch is full
            if len(batch_x) == bucket_size:
                # Halve the batch size if the longest sequence is too long
                if (
                    (bucket_size >= 2)
                    and (max(batch_len) > HALF_BATCHSIZE_TIME)
                    and self.sample_length == 0
                ):
                    self.X.append(batch_x[: bucket_size // 2])
                    self.X.append(batch_x[bucket_size // 2 :])
                else:
                    self.X.append(batch_x)
                batch_x, batch_len = [], []

        # Gather the last (partial) batch; a single leftover utterance is dropped
        if len(batch_x) > 1:
            self.X.append(batch_x)

    def _sample(self, x):
        # Randomly crop a segment of sample_length samples; no-op if sampling
        # is disabled or the utterance is shorter than sample_length
        if self.sample_length <= 0:
            return x
        if len(x) < self.sample_length:
            return x
        idx = random.randint(0, len(x) - self.sample_length)
        return x[idx : idx + self.sample_length]

    def __len__(self):
        return len(self.X)

    def collate_fn(self, items):
        # Bucketing hack: the DataLoader uses batch_size=1, so each item is
        # already a full bucket produced by __getitem__
        items = items[0]
        assert (
            len(items) == 4
        ), "__getitem__ should return (wave_input, wave_orig, wave_len, pad_mask)"
        return items


class OnlineWaveDataset(WaveDataset):
    """Online waveform dataset"""

    def __init__(
        self,
        task_config,
        bucket_size,
        file_path,
        sets,
        max_timestep=0,
        libri_root=None,
        target_level=-25,
        **kwargs
    ):
        super().__init__(
            task_config,
            bucket_size,
            file_path,
            sets,
            max_timestep,
            libri_root,
            **kwargs
        )
        self.target_level = target_level

    def _load_feat(self, feat_path):
        if self.libri_root is None:
            return torch.FloatTensor(np.load(os.path.join(self.root, feat_path)))
        wav, _ = torchaudio.load(os.path.join(self.libri_root, feat_path))
        return wav.squeeze()  # (seq_len,); assumes single-channel audio

    def __getitem__(self, index):
        # Load waveforms (or cached features), optionally crop, and pad
        x_batch = [self._sample(self._load_feat(x_file)) for x_file in self.X[index]]
        x_lens = [len(x) for x in x_batch]
        x_lens = torch.LongTensor(x_lens)
        x_pad_batch = pad_sequence(x_batch, batch_first=True)

        pad_mask = torch.ones(x_pad_batch.shape)  # (batch_size, seq_len)
        # Zero out positions beyond each sequence's true length
        for idx in range(x_pad_batch.shape[0]):
            pad_mask[idx, x_lens[idx] :] = 0

        return [x_pad_batch, x_batch, x_lens, pad_mask]

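
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the dataset API).
# Assumptions: "data/len_for_bucket" contains a "train-clean-100.csv" with
# "file_path" and "length" columns, and the libri_root path points at the raw
# LibriSpeech audio; adjust both to your setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    task_config = {"sequence_length": 250000}  # 0 disables random cropping
    dataset = OnlineWaveDataset(
        task_config,
        bucket_size=8,
        file_path="data/len_for_bucket",   # assumed CSV directory
        sets=["train-clean-100"],          # assumed split name
        libri_root="/path/to/LibriSpeech", # assumed audio root
    )

    # Each __getitem__ already returns a full bucket, so the DataLoader batch
    # size must stay 1 and collate_fn simply unwraps the bucket.
    loader = DataLoader(
        dataset, batch_size=1, shuffle=True, collate_fn=dataset.collate_fn
    )
    wave_input, wave_orig, wave_len, pad_mask = next(iter(loader))
    print(wave_input.shape, wave_len, pad_mask.shape)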